<html><!-- Created using the cpp_pretty_printer from the dlib C++ library.  See http://dlib.net for updates. --><head><title>dlib C++ Library - cross_validate_track_association_trainer.h</title></head><body bgcolor='white'><pre>
<font color='#009900'>// Copyright (C) 2014  Davis E. King (davis@dlib.net)
</font><font color='#009900'>// License: Boost Software License   See LICENSE.txt for the full license.
</font><font color='#0000FF'>#ifndef</font> DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_H__
<font color='#0000FF'>#define</font> DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_H__

<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='cross_validate_track_association_trainer_abstract.h.html'>cross_validate_track_association_trainer_abstract.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='structural_track_association_trainer.h.html'>structural_track_association_trainer.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../assert.h.html'>../assert.h</a>"

<font color='#0000FF'>namespace</font> dlib
<b>{</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>namespace</font> impl
    <b>{</b>
        <font color='#009900'>// Replays a single track history through the association function held by assoc</font>
        <font color='#009900'>// and scores it against the ground-truth labels.  samples[j] holds the labeled</font>
        <font color='#009900'>// detections for time step j.  Every detection adds 1 to total_dets, and adds 1 to</font>
        <font color='#009900'>// correctly_associated_dets when it is placed consistently with its label:</font>
        <font color='#009900'>//   - into an existing track: correct iff that is the track the same label went</font>
        <font color='#009900'>//     into most recently;</font>
        <font color='#009900'>//   - into a new track: correct iff the label has not been seen before.</font>
        <font color='#009900'>// Both counters are only ever incremented, never reset, so callers must initialize</font>
        <font color='#009900'>// them; this lets several histories be pooled into a single score.</font>
        <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
            <font color='#0000FF'>typename</font> track_association_function,
            <font color='#0000FF'>typename</font> detection_type,
            <font color='#0000FF'>typename</font> label_type
            <font color='#5555FF'>&gt;</font>
        <font color='#0000FF'><u>void</u></font> <b><a name='test_track_association_function'></a>test_track_association_function</b> <font face='Lucida Console'>(</font>
            <font color='#0000FF'>const</font> track_association_function<font color='#5555FF'>&amp;</font> assoc,
            <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>labeled_detection<font color='#5555FF'>&lt;</font>detection_type,label_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> samples,
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&amp;</font> total_dets,
            <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>&amp;</font> correctly_associated_dets
        <font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// f maps the current (unlabeled) detections and the tracks built so far to one</font>
            <font color='#009900'>// track index per detection; see its use inside the loop below.</font>
            <font color='#0000FF'>const</font> <font color='#0000FF'>typename</font> track_association_function::association_function_type<font color='#5555FF'>&amp;</font> f <font color='#5555FF'>=</font> assoc.<font color='#BB00BB'>get_assignment_function</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

            <font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> detection_type::track_type track_type;
            <font color='#009900'>// NOTE(review): this using-directive appears redundant because the function is</font>
            <font color='#009900'>// already declared inside namespace impl.</font>
            <font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> impl;

            <font color='#009900'>// Local RNG used only to shuffle each time step; it is constructed anew on every</font>
            <font color='#009900'>// call.  tracks holds the hypothesized tracks built from the association decisions.</font>
            dlib::rand rnd;
            std::vector<font color='#5555FF'>&lt;</font>track_type<font color='#5555FF'>&gt;</font> tracks;
            std::map<font color='#5555FF'>&lt;</font>label_type,<font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> track_idx; <font color='#009900'>// tracks[track_idx[id]] == track with ID id.
</font>
            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> j <font color='#5555FF'>=</font> <font color='#979000'>0</font>; j <font color='#5555FF'>&lt;</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>j<font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#009900'>// Copy this time step's detections so they can be shuffled without touching</font>
                <font color='#009900'>// the const input.</font>
                std::vector<font color='#5555FF'>&lt;</font>labeled_detection<font color='#5555FF'>&lt;</font>detection_type,label_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> dets <font color='#5555FF'>=</font> samples[j];
                <font color='#009900'>// Shuffle the order of the detections so we can be sure that there isn't
</font>                <font color='#009900'>// anything funny going on like the detections always coming in the same
</font>                <font color='#009900'>// order relative to their labels and the association function just gets
</font>                <font color='#009900'>// lucky by picking the same assignment ordering every time.  So this way
</font>                <font color='#009900'>// we know the assignment function really is doing something rather than
</font>                <font color='#009900'>// just being lucky.
</font>                <font color='#BB00BB'>randomize_samples</font><font face='Lucida Console'>(</font>dets, rnd<font face='Lucida Console'>)</font>;

                total_dets <font color='#5555FF'>+</font><font color='#5555FF'>=</font> dets.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
                <font color='#009900'>// assignments[k] is the index into tracks chosen for dets[k], or -1 to start a</font>
                <font color='#009900'>// new track.  NOTE(review): assumes f returns exactly one entry per detection.</font>
                std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font> assignments <font color='#5555FF'>=</font> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#BB00BB'>get_unlabeled_dets</font><font face='Lucida Console'>(</font>dets<font face='Lucida Console'>)</font>, tracks<font face='Lucida Console'>)</font>;
                <font color='#009900'>// Sized before any new tracks are appended below, so tracks created during this</font>
                <font color='#009900'>// time step are not propagated at the end of it.</font>
                std::vector<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>bool</u></font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>updated_track</font><font face='Lucida Console'>(</font>tracks.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, <font color='#979000'>false</font><font face='Lucida Console'>)</font>;
                <font color='#009900'>// now update all the tracks with the detections that associated to them.
</font>                <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'>&lt;</font> assignments.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font>
                <b>{</b>
                    <font color='#009900'>// If the detection is associated to tracks[assignments[k]]
</font>                    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>assignments[k] <font color='#5555FF'>!</font><font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>
                    <b>{</b>
                        tracks[assignments[k]].<font color='#BB00BB'>update_track</font><font face='Lucida Console'>(</font>dets[k].det<font face='Lucida Console'>)</font>;
                        updated_track[assignments[k]] <font color='#5555FF'>=</font> <font color='#979000'>true</font>;

                        <font color='#009900'>// if this detection was supposed to go to this track
</font>                        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>track_idx.<font color='#BB00BB'>count</font><font face='Lucida Console'>(</font>dets[k].label<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> track_idx[dets[k].label]<font color='#5555FF'>=</font><font color='#5555FF'>=</font>assignments[k]<font face='Lucida Console'>)</font>
                            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>correctly_associated_dets;

                        <font color='#009900'>// Record where this label actually went, even when that was wrong, so later</font>
                        <font color='#009900'>// time steps are judged against the tracker's own decision and a single</font>
                        <font color='#009900'>// mistake is not counted again on every subsequent step.</font>
                        track_idx[dets[k].label] <font color='#5555FF'>=</font> assignments[k];
                    <b>}</b>
                    <font color='#0000FF'>else</font>
                    <b>{</b>
                        track_type new_track;
                        new_track.<font color='#BB00BB'>update_track</font><font face='Lucida Console'>(</font>dets[k].det<font face='Lucida Console'>)</font>;
                        tracks.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>new_track<font face='Lucida Console'>)</font>;

                        <font color='#009900'>// if this detection was supposed to go to a new track
</font>                        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>track_idx.<font color='#BB00BB'>count</font><font face='Lucida Console'>(</font>dets[k].label<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
                            <font color='#5555FF'>+</font><font color='#5555FF'>+</font>correctly_associated_dets;

                        track_idx[dets[k].label] <font color='#5555FF'>=</font> tracks.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>-</font><font color='#979000'>1</font>;
                    <b>}</b>
                <b>}</b>

                <font color='#009900'>// Now propagate all the tracks that didn't get any detections.
</font>                <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> k <font color='#5555FF'>=</font> <font color='#979000'>0</font>; k <font color='#5555FF'>&lt;</font> updated_track.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>k<font face='Lucida Console'>)</font>
                <b>{</b>
                    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>updated_track[k]<font face='Lucida Console'>)</font>
                        tracks[k].<font color='#BB00BB'>propagate_track</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
                <b>}</b>
            <b>}</b>
        <b>}</b>
    <b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
        <font color='#0000FF'>typename</font> track_association_function,
        <font color='#0000FF'>typename</font> detection_type,
        <font color='#0000FF'>typename</font> label_type
        <font color='#5555FF'>&gt;</font>
    <font color='#0000FF'><u>double</u></font> <b><a name='test_track_association_function'></a>test_track_association_function</b> <font face='Lucida Console'>(</font>
        <font color='#0000FF'>const</font> track_association_function<font color='#5555FF'>&amp;</font> assoc,
        <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>labeled_detection<font color='#5555FF'>&lt;</font>detection_type,label_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> samples
    <font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> total_dets <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> correctly_associated_dets <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
        <b>{</b>
            impl::<font color='#BB00BB'>test_track_association_function</font><font face='Lucida Console'>(</font>assoc, samples[i], total_dets, correctly_associated_dets<font face='Lucida Console'>)</font>;
        <b>}</b>

        <font color='#0000FF'>return</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>correctly_associated_dets<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>total_dets;
    <b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font>
        <font color='#0000FF'>typename</font> trainer_type,
        <font color='#0000FF'>typename</font> detection_type,
        <font color='#0000FF'>typename</font> label_type
        <font color='#5555FF'>&gt;</font>
    <font color='#0000FF'><u>double</u></font> <b><a name='cross_validate_track_association_trainer'></a>cross_validate_track_association_trainer</b> <font face='Lucida Console'>(</font>
        <font color='#0000FF'>const</font> trainer_type<font color='#5555FF'>&amp;</font> trainer,
        <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>labeled_detection<font color='#5555FF'>&lt;</font>detection_type,label_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> samples,
        <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> folds
    <font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num_in_test  <font color='#5555FF'>=</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font color='#5555FF'>/</font>folds;
        <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num_in_train <font color='#5555FF'>=</font> samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>-</font> num_in_test;

        std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>labeled_detection<font color='#5555FF'>&lt;</font>detection_type,label_type<font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> <font color='#5555FF'>&gt;</font> samples_train;

        <font color='#0000FF'><u>long</u></font> next_test_idx <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> total_dets <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
        <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> correctly_associated_dets <font color='#5555FF'>=</font> <font color='#979000'>0</font>;

        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'>&lt;</font> folds; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
        <b>{</b>
            samples_train.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

            <font color='#009900'>// load up the training samples
</font>            <font color='#0000FF'><u>long</u></font> next <font color='#5555FF'>=</font> <font face='Lucida Console'>(</font>next_test_idx <font color='#5555FF'>+</font> num_in_test<font face='Lucida Console'>)</font><font color='#5555FF'>%</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> cnt <font color='#5555FF'>=</font> <font color='#979000'>0</font>; cnt <font color='#5555FF'>&lt;</font> num_in_train; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>cnt<font face='Lucida Console'>)</font>
            <b>{</b>
                samples_train.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>samples[next]<font face='Lucida Console'>)</font>;
                next <font color='#5555FF'>=</font> <font face='Lucida Console'>(</font>next <font color='#5555FF'>+</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <b>}</b>

            <font color='#0000FF'>const</font> track_association_function<font color='#5555FF'>&lt;</font>detection_type<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> df <font color='#5555FF'>=</font> trainer.<font color='#BB00BB'>train</font><font face='Lucida Console'>(</font>samples_train<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> cnt <font color='#5555FF'>=</font> <font color='#979000'>0</font>; cnt <font color='#5555FF'>&lt;</font> num_in_test; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>cnt<font face='Lucida Console'>)</font>
            <b>{</b>
                impl::<font color='#BB00BB'>test_track_association_function</font><font face='Lucida Console'>(</font>df, samples[next_test_idx], total_dets, correctly_associated_dets<font face='Lucida Console'>)</font>;
                next_test_idx <font color='#5555FF'>=</font> <font face='Lucida Console'>(</font>next_test_idx <font color='#5555FF'>+</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font><font color='#5555FF'>%</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <b>}</b>
        <b>}</b>

        <font color='#0000FF'>return</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>correctly_associated_dets<font color='#5555FF'>/</font><font face='Lucida Console'>(</font><font color='#0000FF'><u>double</u></font><font face='Lucida Console'>)</font>total_dets;
    <b>}</b>

<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<b>}</b>

<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_CROSS_VALIDATE_TRACK_ASSOCIATION_TrAINER_H__
</font>


</pre></body></html>